# Star import kept first so the explicit imports below take precedence over
# any same-named objects utils may export (preserves the original shadowing order).
from utils import *
import os
import sys
import time
from math import inf

from sklearn.cluster import KMeans
from tabulate import tabulate

# Experiment: for each trial, solve the DRO problem on all N_tot generated
# samples (reference value tot_opt), then re-solve it on K-means cluster
# centres for N = N_CLUSTERS_MIN..N_CLUSTERS_MAX clusters and record the
# full-data utility of each clustered solution, the solve time and the
# K-means inertia. One pickled dict per trial is written under RESULTS_DIR.

N_TRIALS = 20
N_CLUSTERS_MIN = 2
N_CLUSTERS_MAX = 20  # inclusive
RESULTS_DIR = './simulations'

m = 7
K = 10
M = 1.0
# Forwarded to generate_mu as its a/b arguments (semantics defined in utils).
# Named mu_a/mu_b so the `b` returned by compute_params cannot clobber them.
mu_a = 3.0
mu_b = 3.0
low = 5.0
high = 8.0  # an alternative setting used elsewhere is high = 10.0
N_tot = 500
var = 3.0
xi = 1e6

# np.save does not create missing directories.
os.makedirs(RESULTS_DIR, exist_ok=True)

for trial in range(N_TRIALS):
    mu0 = generate_mu(m, low, high, mu_a, mu_b)
    mu1 = generate_mu(m, low, high, mu_a, mu_b)
    mu2 = generate_mu(m, low, high, mu_a, mu_b)
    mu3 = generate_mu(m, low, high, mu_a, mu_b)
    log = {'mu0' : mu0, 'mu1' : mu1, 'mu2' : mu2, 'mu3' : mu3}

    # Generate attacker data: four Gaussians sharing an isotropic covariance.
    cov = var * np.eye(m)
    theta_1, theta_2 = generate_theta_normal(mu0, cov, mu1, cov, mu2, cov, mu3, cov, N_tot)
    full_theta = np.concatenate([theta_1, theta_2], axis=2)
    print("Mu0 : ", mu0)
    print("Mu1 : ", mu1)
    print("Mu2 : ", mu2)
    print("Mu3 : ", mu3)

    # Reference solution using every sample. The weight vector has N_tot + 1
    # entries with the last set to 0; presumably FCP_DRO / utility_robust
    # expect that trailing slot (the clustered weights below mirror it) —
    # confirm against utils.
    A, b, C, d, tL, tU = compute_params(m, K, N_tot, numerator, denominator, full_theta)
    w = np.ones(N_tot + 1)
    w[-1] = 0
    tot_z_dro = FCP_DRO(m, K, N_tot, N_tot, M, tL, tU, A, b, C, d, w, xi)
    tot_opt = utility_robust(full_theta, numerator, denominator, m, tot_z_dro, N_tot, xi, w)
    print(tot_opt)

    # Loop-invariant: K-means features (one flattened row per sample) and the
    # per-sample shape needed to turn centroids back into theta-like arrays.
    flat_theta = full_theta.reshape(N_tot, -1)
    sample_shape = full_theta.shape[1:]

    all_opt = []
    all_times = []
    all_losses = []
    for N in range(N_CLUSTERS_MIN, N_CLUSTERS_MAX + 1):
        kmeans = KMeans(n_clusters=N).fit(flat_theta)
        all_losses.append(kmeans.inertia_)
        cluster_cent = kmeans.cluster_centers_.reshape((N,) + sample_shape)
        labels = kmeans.predict(flat_theta)
        # Cluster sizes as float weights; length N + 1 so the trailing slot is
        # 0, matching the layout of the full-data weight vector `w`.
        s = np.bincount(labels, minlength=N + 1).astype(float)
        A, b, C, d, tL, tU = compute_params(m, K, N, numerator, denominator, cluster_cent)

        # Time only the clustered solve; evaluation below is not part of it.
        sta = time.perf_counter()
        z_dro = FCP_DRO(m, K, N, N_tot, M, tL, tU, A, b, C, d, s, xi)
        fin = time.perf_counter()

        # Score the clustered solution on the full sample set so it is directly
        # comparable with tot_opt.
        opt = utility_robust(full_theta, numerator, denominator, m, z_dro, N_tot, xi, w)
        diff = 100 * (tot_opt - opt) / tot_opt  # percent gap to the full-data value
        print("N : ", N, " diff : ", diff)
        all_opt.append(opt)
        all_times.append(fin - sta)

    results_location = os.path.join(
        RESULTS_DIR, 'SSG_global_optimality_{}_trial_{}.npy'.format(m, trial))
    log['all_opt'] = all_opt
    log['all_times'] = all_times
    log['tot_opt'] = tot_opt
    log['all_losses'] = all_losses
    # Saved as a 0-d object array; load with np.load(..., allow_pickle=True).item().
    np.save(results_location, log)
